%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Codes for Example 3.2 
% ETKF with Berry-Sauer method, implemented on a 3-dim SDE
% Created by John Harlim 
% Last edited: March 16, 2018  
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Estimating Q and R

% ---- Setup: filter parameters, true noise levels, synthetic observations ----
% Loads triad.mat (assumed to provide the truth trajectory x, the step DT,
% the noise amplitude sigma and the globals below -- TODO confirm), builds the
% observation operator and noisy observations, and initializes the ensemble
% together with the initial guesses of the Q and R estimates.

clear all, close all, tic

% Globals presumably read by the triad() model function -- confirm.
global gamma a omega beta 

% filter parameters
load triad
n = 3;          % dimension of the state
TCYC = 10000;   % total number of assimilation cycles
tau = 500;      % relaxation parameter of the running Q/R averages
t_est = 1;      % Q is estimated only after t_est cycles
flagR = 1;      % 0: do not estimate R, 1: estimate R
flagQ = 1;      % 0: do not estimate Q, 1: Q for full obs, 2: Q for partial obs

% True Q and R
% NOTE(review): if sigma is a scalar, this assignment fills the whole 2x2
% block including the off-diagonal entries, whereas the parameterized Q
% estimator uses a diagonal Q -- confirm which is intended.
Q = zeros(n);
Q(2:3,2:3) = sigma^2*DT;
R = .05;

% setting observations
m = 1;
% With partial observations (m < n) the unparameterized Q estimator cannot
% be used, so switch to the parameterized one -- but only when Q is being
% estimated at all, so that flagQ = 0 is still honored.
if (m<n && flagQ>0)
    flagQ = 2;
end
H = zeros(m,n);
%H(1:m,1:m) = eye(m);
H(2) = 1;       % linear index: for m = 1 this observes the 2nd state component
y = H*x(:,1:TCYC) + sqrt(R)*randn(m,TCYC);

K = 10; % ensemble size

xa = zeros(n,K,TCYC);
xb = zeros(n,K,TCYC+1);
Qtilde = zeros(n,n,TCYC);
Rtilde = zeros(m,m,TCYC);
xb(:,:,1) = x(:,1)*ones(1,K)+ .1*randn(n,K);

% Initial covariances: identity when the quantity is estimated (QQ/RR hold
% the running averages), the true value otherwise.
if (flagQ>0)
    Qtilde(:,:,1) = eye(n);
    QQ = eye(n);
else
    Qtilde(:,:,1) = Q;
end

if (flagR>0)
    Rtilde(:,:,1) = eye(m);
    RR = eye(m);
else
    Rtilde(:,:,1) = R*eye(m);
end

% ---- Main assimilation loop: ETKF with Berry-Sauer adaptive Q and R ----
% Each cycle j: add the current Q estimate to the deterministic prior
% covariance and resample the ensemble, run the ETKF analysis with the
% current R estimate, update the running R and Q estimates from the
% innovations, then forecast the analysis ensemble to cycle j+1.

HT = H';                 % loop-invariant transpose of the observation operator
d = zeros(m,TCYC-1);     % innovations, one column per cycle (preallocated)
% The Q update needs Ua, Uold, Kgold and d(:,j-1) from the previous cycle,
% so it can never start before cycle 2, even if t_est < 1.
t_start = max(t_est,1);

for j=1:TCYC-1
    % deterministic prior ensemble at time j
    xbbar = mean(xb(:,:,j),2);
    U = xb(:,:,j)-xbbar*ones(1,K);
    Pb = U*U'/(K-1);
    yb = H*xb(:,:,j);
    ybbar = mean(yb,2);
    V = yb - ybbar*ones(1,K);    
    Uprev = U;           % kept for the next cycle's Q estimate (becomes Uold)
    
    % least-squares fit of HF with V = HF*Ua, using the previous analysis perturbations
    if (j>t_start)
        HF = V/Ua;        
    end
       
    % constructing stochastic ensemble: resample perturbations whose sample
    % covariance equals Pb + Qtilde(:,:,j)
    Pb = Pb + Qtilde(:,:,j);    
    RM = randn(n,K);
    RM2 = RM - mean(RM,2)*ones(1,K);        % center the draws
    [EF,ES] = eig(RM2*RM2'/(K-1));
    ES2 = diag(1./sqrt(diag(ES)));          % whitening: sample covariance -> identity
    U = sqrtm(Pb)*EF*ES2*EF'*RM2;
    xb(:,:,j) = xbbar*ones(1,K)+ U;

    % constructing prior perturbation
    xbbar = mean(xb(:,:,j),2); 
    U = xb(:,:,j) - xbbar*ones(1,K);
    yb = H*xb(:,:,j);
    ybbar = mean(yb,2);
    V = yb - ybbar*ones(1,K);    
    d(:,j) =  y(:,j) - ybbar;               % innovation at cycle j
    
    % ETKF
    J = (eye(K)*(K-1)+V'*(Rtilde(:,:,j)\V));
    x2 = J\(V'*(Rtilde(:,:,j)\d(:,j)));
    xabar = xbbar + U*x2;    
    [X,Lambda] = svd(J);
    S2 = diag((K-1)./diag(Lambda));
    T = X*sqrtm(S2)*X';                     % symmetric square-root transform
    Py = V*V'/(K-1)+Rtilde(:,:,j);
    Pxy = (U*V')/(K-1); 
    Kg = Pxy/Py;                            % ensemble Kalman gain
    Ua = U*T;
    xa(:,:,j) = xabar*ones(1,K) +Ua;
    Kgprev = Kg;                            % becomes Kgold next cycle
    
    % estimating R: relax RR toward d*d' - V*V'/(K-1)
    if (flagR ==0)
        Rtilde(:,:,j+1) = Rtilde(:,:,j);
    else
        Re = d(:,j)*d(:,j)'-V*V'/(K-1);
        RR = RR + 1/tau*(Re - RR);
        % For symmetric RR, T1*T2*T1' maps every eigenvalue to its absolute value.
        [T1,T2] = svd(RR);
        Rtilde(:,:,j+1) = T1*T2*T1';
    end
    
    % estimating Q
    if (flagQ==0) % not estimating Q
        Qtilde(:,:,j+1) = Qtilde(:,:,j);
    else            % estimating Q
        if (j>t_start) % only after t_est cycles (and never before cycle 2)
            if (flagQ==1)   % without parameterization
                MM = (HF\(d(:,j)*d(:,j-1)')+Kgold*(d(:,j-1)*d(:,j-1)'))/HT;
                Qest = MM - Uold*Uold'/(K-1);
                Qest = (Qest+Qest')/2;
            else            % with parameterization: Q = DD*diag(0,1,1), fit by least squares
                CC = d(:,j)*d(:,j-1)'+HF*Kgold*(d(:,j-1)*d(:,j-1)')-HF*Uold*Uold'*HT/(K-1);    
                CC = CC(:);        
                NQ = 1;
                BB = zeros(NQ,length(CC));
                for i=1:NQ        
                   temp1 = zeros(n,n);
                   temp1(2,2)=1;
                   temp1(3,3)=1;
                   temp = HF*temp1*HT;
                   BB(i,:) = temp(:);
                end
                                
                DD = BB'\CC;    
                Qest = zeros(n,n);
                for i=1:NQ
                    temp1 = zeros(n,n);
                    temp1(2,2)=DD(i);
                    temp1(3,3)=DD(i);
                    Qest = Qest + temp1;
                end
            end
            % relax toward the new estimate, then clip negative eigenvalues to zero
            QQ = QQ + 1/tau*(Qest-QQ);
            [UU,FF] = eig(QQ/DT);
            Qtilde(:,:,j+1) = UU*max(FF,0)*UU'*DT;
        else
            Qtilde(:,:,j+1) = Qtilde(:,:,j);
        end
    end    
    
    % deterministic forecast (forward-Euler form; assumes triad() returns the
    % deterministic tendency -- confirm)
    xb(:,:,j+1)=xa(:,:,j)+ DT*triad(xa(:,:,j));
    Uold = Uprev;
    Kgold = Kgprev;
end

% ---- Diagnostics: analysis RMS error and ensemble spread ----
% Statistics are averaged over cycles t_spin+1 .. TCYC-1. The assimilation
% loop runs j = 1:TCYC-1, so xa(:,:,TCYC) is never filled (it stays zero)
% and must be excluded from both averages.
t_spin = 2000;                 % spin-up cycles discarded from the statistics
idx = t_spin+1:TCYC-1;

meanxa = squeeze(mean(xa,2));
rmsa = (meanxa-x(:,1:TCYC)).^2;
rms = sqrt(mean(mean(rmsa(:,idx))))

% per-member deviations from the ensemble mean -> diagonal of the analysis covariance
Xa = zeros(n,TCYC,K);
for i=1:K
    Xa(:,:,i) = squeeze(xa(:,i,:)) - meanxa;
end
Padiag = sum(Xa.^2,3)/(K-1);
spread = sqrt(mean(mean(Padiag(:,idx))))

toc

% ---- Figure 1: posterior-mean state estimate vs. truth ----
grey = [0.4, 0.4, 0.4];

% One time stamp per filled analysis cycle 1..TCYC-1; built from integers so
% its length always equals TCYC-1 (a floating-point colon range may not).
time = DT*(1:TCYC-1);
labels = {'u','v','w'};
figure(1)
for j=1:3
    subplot(3,1,j)
    hold on
    if (tau==100)
        % the LineSpec ('--') must come before the Name/Value pairs
        plot(time,squeeze(mean(xa(j,:,1:TCYC-1),2)),'--','color',grey,'linewidth',2)
    else
        plot(time,squeeze(mean(xa(j,:,1:TCYC-1),2)),'k','linewidth',2)
        plot(time,x(j,1:TCYC-1),'k--')
        
        %if(j==1)
        %legend('estimate \tau=100','estimate \tau=100','truth')
        %end
    end
    hold off
    xlim([900 1000])
    %legend('estimate','truth')
    ylabel(labels{j})
    xlabel('time')
end

%print -depsc -r100 state_etkfqrbs.eps


% ---- Figure 2: estimated noise parameters vs. their true values ----
% Every 10th filled cycle in 1..TCYC-1; time stamps and estimates are indexed
% by the same vector, so their lengths always agree.
kplot = 1:10:TCYC-1;
tplot = DT*kplot;

figure(2)
subplot(2,1,1)
hold on
if (tau==100)
    % the LineSpec ('--') must come before the Name/Value pairs
    plot(tplot,squeeze(sqrt(Qtilde(2,2,kplot)/DT)),'--','color',grey,'linewidth',2)
else
    plot(tplot,squeeze(sqrt(Qtilde(2,2,kplot)/DT)),'k','linewidth',2)
    plot(tplot,sigma(1,1)*ones(size(tplot)),'k--')
    %legend('estimate \tau=100','estimate \tau=100','truth')
end
hold off
xlim([0 TCYC*DT])
%legend('estimate \tau=100','truth')
ylabel('\sigma')

subplot(2,1,2)
hold on
if (tau==100)
    plot(tplot,squeeze(Rtilde(1,1,kplot)),'--','color',grey,'linewidth',2)
else
    plot(tplot,squeeze(Rtilde(1,1,kplot)),'k','linewidth',2)
    plot(tplot,R*ones(size(tplot)),'k--')
end
hold off
xlim([0 TCYC*DT])
ylabel('R')
xlabel('time')

%print -depsc -r100 sparm_etkfqrbs.eps

    

